{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# execute this in command line before initiating the notebook: \n",
    "#    pip install -U pip\n",
    "#    pip install -U ipywidgets==7.5.1\n",
    "#    jupyter nbextension enable --py widgetsnbextension\n",
    "\n",
    "# pip install with locked versions\n",
    "# NOTE: a requirement containing '>' must be quoted; unquoted, the shell treats '>' as an\n",
    "# output redirection, installs the package without the version bound and writes pip's\n",
    "# output to a file named '=0.16.1'.\n",
    "! pip install -U torch==1.5.1\n",
    "! pip install -U torchvision==0.6.1\n",
    "! pip install -U numpy==1.18.4\n",
    "! pip install -U \"trains>=0.16.1\"\n",
    "! pip install -U tensorboard==2.2.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from trains import Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register this run with trains; the task object is used below to connect the hyper-parameters.\n",
    "task = Task.init(\n",
    "    project_name='Image Example',\n",
    "    task_name='image classification CIFAR10',\n",
    ")\n",
    "\n",
    "# Default hyper-parameters. Passing them through task.connect() enables\n",
    "# configuration override by trains.\n",
    "configuration_dict = task.connect({\n",
    "    'number_of_epochs': 3,\n",
    "    'batch_size': 4,\n",
    "    'dropout': 0.25,\n",
    "    'base_lr': 0.001,\n",
    "})\n",
    "\n",
    "# Show the configuration actually in effect (after override in remote mode).\n",
    "print(configuration_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The only preprocessing applied is the conversion of the images to tensors.\n",
    "transform = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "# Read the (possibly overridden) batch size once so both loaders are guaranteed to agree.\n",
    "batch_size = configuration_dict.get('batch_size', 4)\n",
    "\n",
    "trainset = datasets.CIFAR10(root='./data', train=True,\n",
    "                            download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
    "                                          shuffle=True, num_workers=2)\n",
    "\n",
    "testset = datasets.CIFAR10(root='./data', train=False,\n",
    "                           download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
    "                                         shuffle=False, num_workers=2)\n",
    "\n",
    "# Human readable class names, indexed by the integer label.\n",
    "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "\n",
    "# Always a torch.device (never a bare integer index) so `device` has the same type with\n",
    "# and without a GPU; 'cuda' without an index means the currently selected CUDA device.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Small convolutional classifier: two conv + max-pool stages followed by three linear layers.\n",
    "\n",
    "    The dropout probability is read from the global ``configuration_dict``\n",
    "    (key ``'dropout'``, default 0.25) at construction time.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 3)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 3)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.fc1 = nn.Linear(16 * 6 * 6, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        # nn.Dropout holds no parameters or buffers, so the attribute name does not\n",
    "        # appear in state_dict() and saved checkpoints stay loadable.\n",
    "        self.dropout = nn.Dropout(p=configuration_dict.get('dropout', 0.25))\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the 10 raw (un-normalised) class scores for a batch of images.\"\"\"\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        # flatten the 16 feature maps into one vector per sample for the linear layers\n",
    "        x = x.view(-1, 16 * 6 * 6)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        # dropout is applied only in front of the final classification layer\n",
    "        x = self.fc3(self.dropout(x))\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network and move it to the selected device.\n",
    "net = Net()\n",
    "net = net.to(device)\n",
    "\n",
    "# Classification loss operating on the raw scores returned by the network.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# SGD with momentum; the learning rate comes from the (overridable) configuration.\n",
    "optimizer = optim.SGD(\n",
    "    net.parameters(),\n",
    "    lr=configuration_dict.get('base_lr', 0.001),\n",
    "    momentum=0.9,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writer used by the evaluation and training cells to log scalars and debug images.\n",
    "tensorboard_writer = SummaryWriter(log_dir='./tensorboard_logs')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_model(test_dataloader, iteration):\n",
    "    \"\"\"Evaluate the global ``net`` on ``test_dataloader`` and report the results.\n",
    "\n",
    "    Prints the per-class and the overall accuracy, logs both as TensorBoard scalars and,\n",
    "    every 500 mini-batches, logs the images of the current batch as debug samples.\n",
    "    The network is switched to evaluation mode for the duration of the call and its\n",
    "    previous mode is restored afterwards.\n",
    "\n",
    "    :param test_dataloader: iterable yielding ``(images, labels)`` batches\n",
    "    :param iteration: global step under which the scalars and images are logged\n",
    "    \"\"\"\n",
    "    num_classes = len(classes)\n",
    "    class_correct = [0.] * num_classes\n",
    "    class_total = [0.] * num_classes\n",
    "\n",
    "    # Dropout must be disabled while measuring accuracy; remember the current mode so the\n",
    "    # caller's training loop continues unaffected.\n",
    "    was_training = net.training\n",
    "    net.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for j, data in enumerate(test_dataloader, 1):\n",
    "                images, labels = data\n",
    "                images = images.to(device)\n",
    "                labels = labels.to(device)\n",
    "\n",
    "                outputs = net(images)\n",
    "                _, predicted = torch.max(outputs, 1)\n",
    "                # No squeeze(): the comparison is already 1-D, and squeezing a batch that\n",
    "                # holds a single sample would produce a 0-d tensor that cannot be indexed.\n",
    "                c = (predicted == labels)\n",
    "                for i in range(len(images)):\n",
    "                    label = labels[i].item()\n",
    "                    class_correct[label] += c[i].item()\n",
    "                    class_total[label] += 1\n",
    "\n",
    "                if j % 500 == 0:    # report debug image every 500 mini-batches\n",
    "                    for n, (img, pred, label) in enumerate(zip(images, predicted, labels)):\n",
    "                        tensorboard_writer.add_image(\"testing/{}-{}_GT_{}_pred_{}\"\n",
    "                                                     .format(j, n, classes[label], classes[pred]), img, iteration)\n",
    "    finally:\n",
    "        net.train(was_training)\n",
    "\n",
    "    for i in range(num_classes):\n",
    "        if not class_total[i]:\n",
    "            # no sample of this class was seen, so an accuracy cannot be computed\n",
    "            continue\n",
    "        class_accuracy = 100 * class_correct[i] / class_total[i]\n",
    "        print('[Iteration {}] Accuracy of {} : {}%'.format(iteration, classes[i], class_accuracy))\n",
    "        tensorboard_writer.add_scalar('accuracy per class/{}'.format(classes[i]), class_accuracy, iteration)\n",
    "\n",
    "    total_seen = sum(class_total)\n",
    "    if not total_seen:\n",
    "        # the loader yielded no samples; there is nothing to report\n",
    "        return\n",
    "    total_accuracy = 100 * sum(class_correct) / total_seen\n",
    "    print('[Iteration {}] Accuracy on the {} test images: {}%\\n'.format(iteration, total_seen, total_accuracy))\n",
    "    tensorboard_writer.add_scalar('accuracy/total', total_accuracy, iteration)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_every = 2000  # report the running loss every `log_every` mini-batches\n",
    "\n",
    "for epoch in range(configuration_dict.get('number_of_epochs', 3)):  # loop over the dataset multiple times\n",
    "    net.train()  # dropout must be active while training\n",
    "\n",
    "    running_loss = 0.0\n",
    "    # defined before the loop so the evaluation call below has a value even if the loader is empty\n",
    "    iteration = epoch * len(trainloader)\n",
    "    for i, data in enumerate(trainloader, 1):\n",
    "        # get the inputs; data is a list of [inputs, labels]\n",
    "        inputs, labels = data\n",
    "        inputs = inputs.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # accumulate the loss for the periodic report\n",
    "        running_loss += loss.item()\n",
    "\n",
    "        # global step across all epochs (i is 1-based)\n",
    "        iteration = epoch * len(trainloader) + i\n",
    "        if i % log_every == 0:\n",
    "            # i already counts from 1, so it is printed as-is\n",
    "            print('[Epoch %d, Iteration %5d] loss: %.3f' % (epoch + 1, i, running_loss / log_every))\n",
    "            tensorboard_writer.add_scalar('training loss', running_loss / log_every, iteration)\n",
    "            running_loss = 0.0\n",
    "\n",
    "    # evaluate on the test set at the end of every epoch\n",
    "    test_model(testloader, iteration)\n",
    "\n",
    "print('Finished Training')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Destination file for the trained weights.\n",
    "PATH = './cifar_net.pth'\n",
    "\n",
    "# Only the state dict (not the whole module object) is written to disk.\n",
    "trained_weights = net.state_dict()\n",
    "torch.save(trained_weights, PATH)\n",
    "\n",
    "# Close the TensorBoard writer now that nothing more will be logged.\n",
    "tensorboard_writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the identifier of the task created at the top of the notebook.\n",
    "task_id = task.id\n",
    "print('Task ID number is: {}'.format(task_id))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
